{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "google",
   "metadata": {},
   "source": [
    "##### Copyright 2023 Google LLC."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "apache",
   "metadata": {},
   "source": [
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "    http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "basename",
   "metadata": {},
   "source": [
    "# shift_scheduling_sat"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "link",
   "metadata": {},
   "source": [
    "<table align=\"left\">\n",
    "<td>\n",
    "<a href=\"https://colab.research.google.com/github/google/or-tools/blob/main/examples/notebook/examples/shift_scheduling_sat.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\"/>Run in Google Colab</a>\n",
    "</td>\n",
    "<td>\n",
    "<a href=\"https://github.com/google/or-tools/blob/main/examples/python/shift_scheduling_sat.py\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\"/>View source on GitHub</a>\n",
    "</td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "doc",
   "metadata": {},
   "source": [
    "First, you must install the [ortools](https://pypi.org/project/ortools/) package in this Colab notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "install",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install ortools"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "description",
   "metadata": {},
   "source": [
    "\n",
    "Creates a shift scheduling problem and solves it.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ortools.sat.colab import flags\n",
    "from ortools.sat.python import cp_model\n",
    "from google.protobuf import text_format\n",
    "\n",
    "# Path of a file to which the text form of the CP-SAT model is written;\n",
    "# the empty default skips writing it.\n",
    "_OUTPUT_PROTO = flags.define_string(\n",
    "    \"output_proto\", \"\", \"Output file to write the cp_model proto to.\"\n",
    ")\n",
    "# CP-SAT solver parameters in protobuf text format, parsed into\n",
    "# solver.parameters before solving (default: max_time_in_seconds:10.0).\n",
    "_PARAMS = flags.define_string(\n",
    "    \"params\", \"max_time_in_seconds:10.0\", \"Sat solver parameters.\"\n",
    ")\n",
    "\n",
    "\n",
    "def negated_bounded_span(works, start, length):\n",
    "    \"\"\"Filters an isolated sub-sequence of variables assined to True.\n",
    "\n",
    "    Extract the span of Boolean variables [start, start + length), negate them,\n",
    "    and if there is variables to the left/right of this span, surround the span by\n",
    "    them in non negated form.\n",
    "\n",
    "    Args:\n",
    "      works: a list of variables to extract the span from.\n",
    "      start: the start to the span.\n",
    "      length: the length of the span.\n",
    "\n",
    "    Returns:\n",
    "      a list of variables which conjunction will be false if the sub-list is\n",
    "      assigned to True, and correctly bounded by variables assigned to False,\n",
    "      or by the start or end of works.\n",
    "    \"\"\"\n",
    "    sequence = []\n",
    "    # Left border (start of works, or works[start - 1])\n",
    "    if start > 0:\n",
    "        sequence.append(works[start - 1])\n",
    "    for i in range(length):\n",
    "        sequence.append(works[start + i].Not())\n",
    "    # Right border (end of works or works[start + length])\n",
    "    if start + length < len(works):\n",
    "        sequence.append(works[start + length])\n",
    "    return sequence\n",
    "\n",
    "\n",
    "def add_soft_sequence_constraint(\n",
    "    model, works, hard_min, soft_min, min_cost, soft_max, hard_max, max_cost, prefix\n",
    "):\n",
    "    \"\"\"Sequence constraint on true variables with soft and hard bounds.\n",
    "\n",
    "    This constraint looks at every maximal contiguous sequence of variables\n",
    "    assigned to true. It forbids sequences of length < hard_min or > hard_max.\n",
    "    Then it creates penalty terms if the length is < soft_min or > soft_max.\n",
    "\n",
    "    Args:\n",
    "      model: the sequence constraint is built on this model.\n",
    "      works: a list of Boolean variables.\n",
    "      hard_min: any sequence of true variables must have a length of at least\n",
    "        hard_min.\n",
    "      soft_min: any sequence should have a length of at least soft_min, or a\n",
    "        linear penalty on the delta will be added to the objective.\n",
    "      min_cost: the coefficient of the linear penalty if the length is less than\n",
    "        soft_min.\n",
    "      soft_max: any sequence should have a length of at most soft_max, or a linear\n",
    "        penalty on the delta will be added to the objective.\n",
    "      hard_max: any sequence of true variables must have a length of at most\n",
    "        hard_max.\n",
    "      max_cost: the coefficient of the linear penalty if the length is more than\n",
    "        soft_max.\n",
    "      prefix: a base name for penalty literals.\n",
    "\n",
    "    Returns:\n",
    "      a tuple (literals_list, coefficient_list) containing the penalty literals\n",
    "      created by the sequence constraint and their objective coefficients. A\n",
    "      penalty literal is forced to true when its matching too-short or too-long\n",
    "      sequence occurs.\n",
    "    \"\"\"\n",
    "    cost_literals = []\n",
    "    cost_coefficients = []\n",
    "\n",
    "    # Forbid sequences that are too short.\n",
    "    for length in range(1, hard_min):\n",
    "        for start in range(len(works) - length + 1):\n",
    "            model.AddBoolOr(negated_bounded_span(works, start, length))\n",
    "\n",
    "    # Penalize sequences that are below the soft limit.\n",
    "    # A sequence contains at least one true variable, so length 0 is never\n",
    "    # penalized, even when hard_min is 0 (an empty span would otherwise\n",
    "    # penalize every pair of adjacent false variables).\n",
    "    if min_cost > 0:\n",
    "        for length in range(max(hard_min, 1), soft_min):\n",
    "            for start in range(len(works) - length + 1):\n",
    "                span = negated_bounded_span(works, start, length)\n",
    "                name = \": under_span(start=%i, length=%i)\" % (start, length)\n",
    "                lit = model.NewBoolVar(prefix + name)\n",
    "                # The clause forces lit to true if this exact short sequence occurs.\n",
    "                span.append(lit)\n",
    "                model.AddBoolOr(span)\n",
    "                cost_literals.append(lit)\n",
    "                # We filter exactly the sequence with a short length.\n",
    "                # The penalty is proportional to the delta with soft_min.\n",
    "                cost_coefficients.append(min_cost * (soft_min - length))\n",
    "\n",
    "    # Penalize sequences that are above the soft limit.\n",
    "    if max_cost > 0:\n",
    "        for length in range(soft_max + 1, hard_max + 1):\n",
    "            for start in range(len(works) - length + 1):\n",
    "                span = negated_bounded_span(works, start, length)\n",
    "                name = \": over_span(start=%i, length=%i)\" % (start, length)\n",
    "                lit = model.NewBoolVar(prefix + name)\n",
    "                # The clause forces lit to true if this exact long sequence occurs.\n",
    "                span.append(lit)\n",
    "                model.AddBoolOr(span)\n",
    "                cost_literals.append(lit)\n",
    "                # Cost paid is max_cost * excess length.\n",
    "                cost_coefficients.append(max_cost * (length - soft_max))\n",
    "\n",
    "    # Just forbid any sequence of true variables with length hard_max + 1;\n",
    "    # this also rules out every longer sequence.\n",
    "    for start in range(len(works) - hard_max):\n",
    "        model.AddBoolOr([works[i].Not() for i in range(start, start + hard_max + 1)])\n",
    "    return cost_literals, cost_coefficients\n",
    "\n",
    "\n",
    "def add_soft_sum_constraint(\n",
    "    model, works, hard_min, soft_min, min_cost, soft_max, hard_max, max_cost, prefix\n",
    "):\n",
    "    \"\"\"Sum constraint with soft and hard bounds.\n",
    "\n",
    "    This constraint counts the variables assigned to true from works.\n",
    "    If forbids sum < hard_min or > hard_max.\n",
    "    Then it creates penalty terms if the sum is < soft_min or > soft_max.\n",
    "\n",
    "    Args:\n",
    "      model: the sequence constraint is built on this model.\n",
    "      works: a list of Boolean variables.\n",
    "      hard_min: any sequence of true variables must have a sum of at least\n",
    "        hard_min.\n",
    "      soft_min: any sequence should have a sum of at least soft_min, or a linear\n",
    "        penalty on the delta will be added to the objective.\n",
    "      min_cost: the coefficient of the linear penalty if the sum is less than\n",
    "        soft_min.\n",
    "      soft_max: any sequence should have a sum of at most soft_max, or a linear\n",
    "        penalty on the delta will be added to the objective.\n",
    "      hard_max: any sequence of true variables must have a sum of at most\n",
    "        hard_max.\n",
    "      max_cost: the coefficient of the linear penalty if the sum is more than\n",
    "        soft_max.\n",
    "      prefix: a base name for penalty variables.\n",
    "\n",
    "    Returns:\n",
    "      a tuple (variables_list, coefficient_list) containing the different\n",
    "      penalties created by the sequence constraint.\n",
    "    \"\"\"\n",
    "    cost_variables = []\n",
    "    cost_coefficients = []\n",
    "    sum_var = model.NewIntVar(hard_min, hard_max, \"\")\n",
    "    # This adds the hard constraints on the sum.\n",
    "    model.Add(sum_var == sum(works))\n",
    "\n",
    "    # Penalize sums below the soft_min target.\n",
    "    if soft_min > hard_min and min_cost > 0:\n",
    "        delta = model.NewIntVar(-len(works), len(works), \"\")\n",
    "        model.Add(delta == soft_min - sum_var)\n",
    "        # TODO(user): Compare efficiency with only excess >= soft_min - sum_var.\n",
    "        excess = model.NewIntVar(0, 7, prefix + \": under_sum\")\n",
    "        model.AddMaxEquality(excess, [delta, 0])\n",
    "        cost_variables.append(excess)\n",
    "        cost_coefficients.append(min_cost)\n",
    "\n",
    "    # Penalize sums above the soft_max target.\n",
    "    if soft_max < hard_max and max_cost > 0:\n",
    "        delta = model.NewIntVar(-7, 7, \"\")\n",
    "        model.Add(delta == sum_var - soft_max)\n",
    "        excess = model.NewIntVar(0, 7, prefix + \": over_sum\")\n",
    "        model.AddMaxEquality(excess, [delta, 0])\n",
    "        cost_variables.append(excess)\n",
    "        cost_coefficients.append(max_cost)\n",
    "\n",
    "    return cost_variables, cost_coefficients\n",
    "\n",
    "\n",
    "def solve_shift_scheduling(params, output_proto):\n",
    "    \"\"\"Builds the shift scheduling model, solves it and prints the result.\n",
    "\n",
    "    Args:\n",
    "      params: CP-SAT solver parameters in protobuf text format, parsed into\n",
    "        solver.parameters; ignored when empty.\n",
    "      output_proto: if non-empty, the path of a file to which the text form of\n",
    "        the model is written before solving.\n",
    "    \"\"\"\n",
    "    # Data\n",
    "    num_employees = 8\n",
    "    num_weeks = 3\n",
    "    shifts = [\"O\", \"M\", \"A\", \"N\"]\n",
    "\n",
    "    # Fixed assignment: (employee, shift, day).\n",
    "    # This fixes the first 2 days of the schedule.\n",
    "    fixed_assignments = [\n",
    "        (0, 0, 0),\n",
    "        (1, 0, 0),\n",
    "        (2, 1, 0),\n",
    "        (3, 1, 0),\n",
    "        (4, 2, 0),\n",
    "        (5, 2, 0),\n",
    "        (6, 2, 0),\n",
    "        (7, 3, 0),\n",
    "        (0, 1, 1),\n",
    "        (1, 1, 1),\n",
    "        (2, 2, 1),\n",
    "        (3, 2, 1),\n",
    "        (4, 2, 1),\n",
    "        (5, 0, 1),\n",
    "        (6, 0, 1),\n",
    "        (7, 3, 1),\n",
    "    ]\n",
    "\n",
    "    # Request: (employee, shift, day, weight)\n",
    "    # A negative weight indicates that the employee desires this assignment.\n",
    "    requests = [\n",
    "        # Employee 3 does not want to work on the first Saturday (negative weight\n",
    "        # for the Off shift).\n",
    "        (3, 0, 5, -2),\n",
    "        # Employee 5 wants a night shift on the second Thursday (negative weight).\n",
    "        (5, 3, 10, -2),\n",
    "        # Employee 2 does not want a night shift on the first Friday (positive\n",
    "        # weight).\n",
    "        (2, 3, 4, 4),\n",
    "    ]\n",
    "\n",
    "    # Shift constraints on continuous sequences:\n",
    "    #     (shift, hard_min, soft_min, min_penalty,\n",
    "    #             soft_max, hard_max, max_penalty)\n",
    "    shift_constraints = [\n",
    "        # One or two consecutive days of rest, this is a hard constraint.\n",
    "        (0, 1, 1, 0, 2, 2, 0),\n",
    "        # between 2 and 3 consecutive days of night shifts, 1 and 4 are\n",
    "        # possible but penalized.\n",
    "        (3, 1, 2, 20, 3, 4, 5),\n",
    "    ]\n",
    "\n",
    "    # Weekly sum constraints on shifts days:\n",
    "    #     (shift, hard_min, soft_min, min_penalty,\n",
    "    #             soft_max, hard_max, max_penalty)\n",
    "    weekly_sum_constraints = [\n",
    "        # Constraints on rests per week.\n",
    "        (0, 1, 2, 7, 2, 3, 4),\n",
    "        # At least 1 night shift per week (penalized). At most 4 (hard).\n",
    "        (3, 0, 1, 3, 4, 4, 0),\n",
    "    ]\n",
    "\n",
    "    # Penalized transitions:\n",
    "    #     (previous_shift, next_shift, penalty (0 means forbidden))\n",
    "    penalized_transitions = [\n",
    "        # Afternoon to night has a penalty of 4.\n",
    "        (2, 3, 4),\n",
    "        # Night to morning is forbidden.\n",
    "        (3, 1, 0),\n",
    "    ]\n",
    "\n",
    "    # Daily demands for work shifts (morning, afternoon, night) for each day\n",
    "    # of the week starting on Monday.\n",
    "    weekly_cover_demands = [\n",
    "        (2, 3, 1),  # Monday\n",
    "        (2, 3, 1),  # Tuesday\n",
    "        (2, 2, 2),  # Wednesday\n",
    "        (2, 3, 1),  # Thursday\n",
    "        (2, 2, 2),  # Friday\n",
    "        (1, 2, 3),  # Saturday\n",
    "        (1, 3, 1),  # Sunday\n",
    "    ]\n",
    "\n",
    "    # Penalty for exceeding the cover constraint per shift type.\n",
    "    excess_cover_penalties = (2, 2, 5)\n",
    "\n",
    "    num_days = num_weeks * 7\n",
    "    num_shifts = len(shifts)\n",
    "\n",
    "    model = cp_model.CpModel()\n",
    "\n",
    "    work = {}\n",
    "    for e in range(num_employees):\n",
    "        for s in range(num_shifts):\n",
    "            for d in range(num_days):\n",
    "                work[e, s, d] = model.NewBoolVar(\"work%i_%i_%i\" % (e, s, d))\n",
    "\n",
    "    # Linear terms of the objective in a minimization context.\n",
    "    obj_int_vars = []\n",
    "    obj_int_coeffs = []\n",
    "    obj_bool_vars = []\n",
    "    obj_bool_coeffs = []\n",
    "\n",
    "    # Exactly one shift per day.\n",
    "    for e in range(num_employees):\n",
    "        for d in range(num_days):\n",
    "            model.AddExactlyOne(work[e, s, d] for s in range(num_shifts))\n",
    "\n",
    "    # Fixed assignments.\n",
    "    for e, s, d in fixed_assignments:\n",
    "        model.Add(work[e, s, d] == 1)\n",
    "\n",
    "    # Employee requests\n",
    "    for e, s, d, w in requests:\n",
    "        obj_bool_vars.append(work[e, s, d])\n",
    "        obj_bool_coeffs.append(w)\n",
    "\n",
    "    # Shift constraints\n",
    "    for ct in shift_constraints:\n",
    "        shift, hard_min, soft_min, min_cost, soft_max, hard_max, max_cost = ct\n",
    "        for e in range(num_employees):\n",
    "            works = [work[e, shift, d] for d in range(num_days)]\n",
    "            variables, coeffs = add_soft_sequence_constraint(\n",
    "                model,\n",
    "                works,\n",
    "                hard_min,\n",
    "                soft_min,\n",
    "                min_cost,\n",
    "                soft_max,\n",
    "                hard_max,\n",
    "                max_cost,\n",
    "                \"shift_constraint(employee %i, shift %i)\" % (e, shift),\n",
    "            )\n",
    "            obj_bool_vars.extend(variables)\n",
    "            obj_bool_coeffs.extend(coeffs)\n",
    "\n",
    "    # Weekly sum constraints\n",
    "    for ct in weekly_sum_constraints:\n",
    "        shift, hard_min, soft_min, min_cost, soft_max, hard_max, max_cost = ct\n",
    "        for e in range(num_employees):\n",
    "            for w in range(num_weeks):\n",
    "                works = [work[e, shift, d + w * 7] for d in range(7)]\n",
    "                variables, coeffs = add_soft_sum_constraint(\n",
    "                    model,\n",
    "                    works,\n",
    "                    hard_min,\n",
    "                    soft_min,\n",
    "                    min_cost,\n",
    "                    soft_max,\n",
    "                    hard_max,\n",
    "                    max_cost,\n",
    "                    \"weekly_sum_constraint(employee %i, shift %i, week %i)\"\n",
    "                    % (e, shift, w),\n",
    "                )\n",
    "                obj_int_vars.extend(variables)\n",
    "                obj_int_coeffs.extend(coeffs)\n",
    "\n",
    "    # Penalized transitions\n",
    "    for previous_shift, next_shift, cost in penalized_transitions:\n",
    "        for e in range(num_employees):\n",
    "            for d in range(num_days - 1):\n",
    "                # Clause is false only if previous_shift on d and next_shift on d + 1.\n",
    "                transition = [\n",
    "                    work[e, previous_shift, d].Not(),\n",
    "                    work[e, next_shift, d + 1].Not(),\n",
    "                ]\n",
    "                if cost == 0:\n",
    "                    model.AddBoolOr(transition)\n",
    "                else:\n",
    "                    trans_var = model.NewBoolVar(\n",
    "                        \"transition (employee=%i, day=%i)\" % (e, d)\n",
    "                    )\n",
    "                    transition.append(trans_var)\n",
    "                    model.AddBoolOr(transition)\n",
    "                    obj_bool_vars.append(trans_var)\n",
    "                    obj_bool_coeffs.append(cost)\n",
    "\n",
    "    # Cover constraints. Shift 0 (Off) has no demand, so it is skipped and the\n",
    "    # columns of weekly_cover_demands are indexed by s - 1.\n",
    "    for s in range(1, num_shifts):\n",
    "        for w in range(num_weeks):\n",
    "            for d in range(7):\n",
    "                works = [work[e, s, w * 7 + d] for e in range(num_employees)]\n",
    "                min_demand = weekly_cover_demands[d][s - 1]\n",
    "                worked = model.NewIntVar(min_demand, num_employees, \"\")\n",
    "                model.Add(worked == sum(works))\n",
    "                over_penalty = excess_cover_penalties[s - 1]\n",
    "                if over_penalty > 0:\n",
    "                    name = \"excess_demand(shift=%i, week=%i, day=%i)\" % (s, w, d)\n",
    "                    excess = model.NewIntVar(0, num_employees - min_demand, name)\n",
    "                    model.Add(excess == worked - min_demand)\n",
    "                    obj_int_vars.append(excess)\n",
    "                    obj_int_coeffs.append(over_penalty)\n",
    "\n",
    "    # Objective\n",
    "    model.Minimize(\n",
    "        sum(var * coeff for var, coeff in zip(obj_bool_vars, obj_bool_coeffs))\n",
    "        + sum(var * coeff for var, coeff in zip(obj_int_vars, obj_int_coeffs))\n",
    "    )\n",
    "\n",
    "    if output_proto:\n",
    "        print(\"Writing proto to %s\" % output_proto)\n",
    "        with open(output_proto, \"w\") as text_file:\n",
    "            text_file.write(str(model))\n",
    "\n",
    "    # Solve the model.\n",
    "    solver = cp_model.CpSolver()\n",
    "    if params:\n",
    "        text_format.Parse(params, solver.parameters)\n",
    "    solution_printer = cp_model.ObjectiveSolutionPrinter()\n",
    "    status = solver.Solve(model, solution_printer)\n",
    "\n",
    "    # Print solution.\n",
    "    if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):\n",
    "        print()\n",
    "        header = \"          \"\n",
    "        for w in range(num_weeks):\n",
    "            header += \"M T W T F S S \"\n",
    "        print(header)\n",
    "        for e in range(num_employees):\n",
    "            schedule = \"\"\n",
    "            for d in range(num_days):\n",
    "                for s in range(num_shifts):\n",
    "                    if solver.BooleanValue(work[e, s, d]):\n",
    "                        schedule += shifts[s] + \" \"\n",
    "            print(\"worker %i: %s\" % (e, schedule))\n",
    "        print()\n",
    "        print(\"Penalties:\")\n",
    "        for var, penalty in zip(obj_bool_vars, obj_bool_coeffs):\n",
    "            if solver.BooleanValue(var):\n",
    "                if penalty > 0:\n",
    "                    print(\"  %s violated, penalty=%i\" % (var.Name(), penalty))\n",
    "                else:\n",
    "                    print(\"  %s fulfilled, gain=%i\" % (var.Name(), -penalty))\n",
    "\n",
    "        for var, coeff in zip(obj_int_vars, obj_int_coeffs):\n",
    "            value = solver.Value(var)\n",
    "            if value > 0:\n",
    "                print(\n",
    "                    \"  %s violated by %i, linear penalty=%i\"\n",
    "                    % (var.Name(), value, coeff)\n",
    "                )\n",
    "\n",
    "    print()\n",
    "    print(\"Statistics\")\n",
    "    print(\"  - status          : %s\" % solver.StatusName(status))\n",
    "    print(\"  - conflicts       : %i\" % solver.NumConflicts())\n",
    "    print(\"  - branches        : %i\" % solver.NumBranches())\n",
    "    print(\"  - wall time       : %f s\" % solver.WallTime())\n",
    "\n",
    "\n",
    "def main(_=None):\n",
    "    \"\"\"Solves the shift scheduling problem using the flag values.\n",
    "\n",
    "    Args:\n",
    "      _: unused positional argument (e.g. argv when main is used as an\n",
    "        application entry point). It defaults to None so that the bare\n",
    "        main() call below works instead of raising TypeError.\n",
    "    \"\"\"\n",
    "    solve_shift_scheduling(_PARAMS.value, _OUTPUT_PROTO.value)\n",
    "\n",
    "\n",
    "main()\n",
    "\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
